In this notebook, we demonstrate how to train a GPT2-like model for writing little stories.
Install Package
We first download the dataset from Hugging Face.
Take a look at the relevant columns; text is the column containing the little stories.
df
Take a look at an example.
Split the dataset into train and test data using train_test_split from sklearn.
from sklearn.model_selection import train_test_split
texts = df['text'].to_list()
train_texts, test_texts = train_test_split(texts, test_size=.2, random_state=42)
(62822, 15706)
To begin, we need a tokenizer that can map a text to a sequence of integers (and vice-versa).
Here we are going to borrow the tokenizer from GPT2 directly.
Download the tokenizer (one line of code thanks to Hugging Face).
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
Encode a text and obtain a sequence of integers; each integer is called a token.
tensor([[31373, 995]])
The difference between the Transformer we discussed in the last lecture and the text-generation Transformer here is the causal mask.
The causal mask (hence the Causal Attention Layer) prevents tokens from accessing information from future tokens.
We start by writing a function that creates the causal mask; a True value indicates that the corresponding position is not allowed to attend.
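The code cell that produced the output below is not shown; a minimal sketch of what such a function might look like (assuming the boolean-mask convention, matching nn.MultiheadAttention, where True marks a blocked position):

```python
import torch

def generate_square_subsequent_mask(seq_len: int) -> torch.Tensor:
    # True strictly above the diagonal: token i may not attend to tokens j > i
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

mask = generate_square_subsequent_mask(5)
```

With seq_len = 5, row 0 blocks columns 1-4 and the last row blocks nothing, matching the triangular pattern printed below.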
tensor([[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
[False, False, False, ..., True, True, True],
...,
[False, False, False, ..., False, True, True],
[False, False, False, ..., False, False, True],
[False, False, False, ..., False, False, False]])

We implement a GPT2-like structure here:
Compared to the standard attention layer, GPT2 adds a causal attention mask and slightly changes the placement of LayerNorm (normalization is applied before the attention and feed-forward blocks).
import torch
import torch.nn as nn
from transformers import activations

class TransformerLayer(nn.Module):
    def __init__(self, embedding_dim, num_heads, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embedding_dim, num_heads, dropout=dropout, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            activations.NewGELUActivation(),
            nn.Linear(4 * embedding_dim, embedding_dim),
        )
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        ## dimension of x is [batch_size, sequence_length, embedding_dim]
        y = self.norm1(x)
        seq_len = x.size(1)
        attn_mask = generate_square_subsequent_mask(seq_len).to(x.device)
        attention_output = self.attention(y, y, y, attn_mask=attn_mask, is_causal=True)[0]
        x = x + self.dropout(attention_output)
        x = x + self.dropout(self.feed_forward(self.norm2(x)))
        return x
class Transformer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_heads, num_layers, max_seq_len, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embedding = nn.Embedding(max_seq_len, embedding_dim)
        self.dropout = nn.Dropout(dropout)
        self.transformer_layers = nn.ModuleList([TransformerLayer(embedding_dim, num_heads) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(embedding_dim)

    def forward(self, inputs):
        # dimension of inputs is [batch_size, sequence_length]
        x = self.embedding(inputs)
        pos_embedding = self.pos_embedding(torch.arange(inputs.size(1), device=inputs.device))
        x = x + pos_embedding
        x = self.dropout(x)  # dimension of x is [batch_size, sequence_length, embedding_dim]
        for layer in self.transformer_layers:
            x = layer(x)
        x = self.ln_f(x)
        return x
Can you add a forward function for the Transformer head for token generation?
class CausalLMTransformer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_heads, num_layers, max_seq_len):
        super().__init__()
        self.base_model = Transformer(vocab_size, embedding_dim, num_heads, num_layers, max_seq_len)
        self.lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)
        self.lm_head.weight = self.base_model.embedding.weight  # GPT2 uses the same weights for the input embedding and the output (decoding) projection

    def forward(self, input_ids, attention_mask):
        x = self.base_model(input_ids)
        logits = self.lm_head(x)
        return logits

max_seq_len = 1024
embedding_dim = 1280
num_heads = 20
num_layers = 12
model = CausalLMTransformer(tokenizer.vocab_size, embedding_dim, num_heads, num_layers, max_seq_len)
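A quick aside on the weight tying above: assigning the embedding's Parameter to lm_head.weight makes the two modules share a single tensor (one set of parameters used in both directions). A minimal sketch with toy dimensions:

```python
import torch.nn as nn

emb = nn.Embedding(10, 4)            # vocab_size=10, embedding_dim=4
head = nn.Linear(4, 10, bias=False)  # Linear weight has shape (out, in) = (10, 4)
head.weight = emb.weight             # tie: both modules now hold the same Parameter
shared = head.weight.data_ptr() == emb.weight.data_ptr()  # same underlying storage
```

Updating the embedding during training therefore also updates the output projection, and the parameter count includes this matrix only once.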
50257
How many parameters for this model?
from typing import Literal
import numpy as np

def count_params(
    method: Literal['loop', 'manual'] = 'manual',
    model=None,
    vocab_size=tokenizer.vocab_size,
    max_seq_len=1024,
    embedding_dim=1280,
    num_heads=20,
    num_layers=12,
):
    pcnt_trainable = 0
    if method == 'loop' and model is not None:
        for p in model.parameters():
            pcnt_trainable += p.numel() * bool(p.requires_grad)
        return pcnt_trainable
    elif method == 'manual':
        ## ------ input layer ------ ##
        ## vocab embedding
        pcnt_trainable = vocab_size * embedding_dim
        ## positional embedding
        pcnt_trainable += max_seq_len * embedding_dim
        ## final LayerNorm: scale + shift
        pcnt_trainable += embedding_dim * 2
        ## ------ transformer layer * num_layers ------ ##
        cnt_per_layer = 0
        ## LayerNorm: scale + shift
        cnt_per_layer += embedding_dim * 2
        ## multi-head attention: K/Q/V projections (weights + biases)
        cnt_per_layer += 3 * embedding_dim * (embedding_dim + 1)
        ## attention output projection
        cnt_per_layer += embedding_dim * (embedding_dim + 1)
        ## LayerNorm
        cnt_per_layer += embedding_dim * 2
        ## feed-forward layer
        cnt_per_layer += (embedding_dim + 1) * embedding_dim * 4
        cnt_per_layer += (4 * embedding_dim + 1) * embedding_dim
        pcnt_trainable += num_layers * cnt_per_layer
        # equivalently: (max_seq_len + vocab_size) * embedding_dim + 2 * embedding_dim + num_layers * (12 * embedding_dim**2 + 13 * embedding_dim)
        return pcnt_trainable

count_params(method='manual', vocab_size=tokenizer.vocab_size)
301771520
count_params(method='loop', model=model)
301771520
Hypothetically, if we change
max_seq_len = 2048
embedding_dim = 12288
num_heads = 96
num_layers = 96
how many parameters would the model have? (This is in fact the configuration of GPT3!)
174604259328
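The manual count collapses to a closed form: (vocab_size + max_seq_len)·d + 2d + num_layers·(12d² + 13d), with d the embedding dimension. A pure-arithmetic sanity check of both configurations (independent of the count_params function above):

```python
def gpt_param_count(vocab_size, max_seq_len, d, num_layers):
    # embeddings + final LayerNorm + per-layer attention, MLP, and two LayerNorms
    return (vocab_size + max_seq_len) * d + 2 * d + num_layers * (12 * d**2 + 13 * d)

small = gpt_param_count(50257, 1024, 1280, 12)   # notebook configuration -> 301771520
gpt3 = gpt_param_count(50257, 2048, 12288, 96)   # GPT-3 configuration -> 174604259328
```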
You can test whether the output dimension is what we want. The shape should be [1, 2, 50257].
input = tokenizer('hello world', return_tensors='pt')
model(**input).shape
torch.Size([1, 2, 50257])
Next, we write a decoding function for generating tokens.
Let's download the GPT2-large model.
from transformers import AutoModelForCausalLM
GPT2 = AutoModelForCausalLM.from_pretrained('gpt2-large')
### greedy_decode
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = '1'
import random
import torch
import numpy as np

# Define the decoding function (temperature sampling; the commented argmax line below gives pure greedy decoding)
def greedy_decode(model, tokenizer, input_text, max_length=50, temperature=1):
    np.random.seed(42)
    torch.manual_seed(42)
    # Set the model to evaluation mode
    model.eval()
    # Tokenize the input text
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    if torch.cuda.is_available():
        model.to('cuda')
        input_ids = input_ids.to('cuda')
    # Initialize the generated sequence with the input_ids
    generated_ids = input_ids
    # Loop to generate the sequence token by token
    for _ in range(max_length):
        # Get model predictions for the current generated_ids
        outputs = model(generated_ids, attention_mask=torch.ones_like(generated_ids, device=generated_ids.device))
        # Get the logits for the last token in the generated sequence
        if isinstance(outputs, torch.Tensor):
            next_token_logits = outputs[:, -1, :]
        else:
            next_token_logits = outputs.logits[:, -1, :]
        # Greedy alternative: select the token with the highest probability
        # next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
        # Convert logits to probabilities
        probabilities = torch.softmax(next_token_logits / temperature, dim=-1)
        # Sample the next token from the probability distribution
        next_token_id = torch.multinomial(probabilities, num_samples=1)
        # Append the predicted token id to the generated sequence
        generated_ids = torch.cat((generated_ids, next_token_id), dim=1)
        # If the generated token is the end-of-sequence token, stop generating
        if next_token_id.item() == tokenizer.eos_token_id:
            break
    # Decode the generated sequence of token ids back to text
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    model.to('cpu')
    return generated_text
# let's test whether the decoder works
input_text = "We are attending an AI class right now."
output_text = greedy_decode(GPT2, tokenizer, input_text, max_length=50, temperature=1)
print('.\n'.join(output_text.split('.')))
We are attending an AI class right now. He has no idea who I am, and had never talked to me before. " she said, "One minute he says all AIs are divided and one is always privileged, and the next minute he says *I'm* speaking for the majority!"
Try a small temperature, so the output is nearly deterministic.
# Example usage
input_text = "We are attending an AI class right now."
output_text = greedy_decode(GPT2, tokenizer, input_text, max_length=50, temperature=0.0001)
print('.\n'.join(output_text.split('.')))
We are attending an AI class right now. I'm a bit nervous, but I'm excited to learn more about the technology. I'm also curious about how it will affect the world. I'm also curious about how it will affect the world. I'm also curious about
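Why a small temperature makes generation nearly deterministic: dividing the logits by T sharpens the softmax, so as T approaches 0 almost all of the probability mass concentrates on the argmax token, and multinomial sampling essentially becomes argmax. A minimal numeric check (made-up logits):

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5])
p_warm = torch.softmax(logits / 1.0, dim=-1)     # moderately spread distribution
p_cold = torch.softmax(logits / 0.0001, dim=-1)  # nearly one-hot on the argmax
```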
Play with the temperature
# Example usage
input_text = "We are attending an AI class right now."
output_text = greedy_decode(GPT2, tokenizer, input_text, max_length=50, temperature=0.4)
print('.\n'.join(output_text.split('.')))
We are attending an AI class right now. It's a class on Artificial Intelligence, and it's a class on the topic of human-level AI. The professor is a very smart guy, and he's not a big fan of AI. He says, "I'm not interested in AI
A helper function for training the model
from torch.utils.data import DataLoader

tokenizer.pad_token = tokenizer.eos_token  # GPT2 was trained on non-padded data, so we use the end-of-sequence token as the padding token

def collate_fn(batch):
    texts = batch
    texts = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=max_seq_len)
    labels = texts['input_ids'][:, 1:].clone()  # next-token targets: inputs shifted left by one
    return texts, labels

batch_size = 4
train_dataloader = DataLoader(train_texts, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=4)
test_dataloader = DataLoader(test_texts, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=4)
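Note how collate_fn builds the labels: they are the input_ids shifted left by one, so the model's prediction at position t is scored against the token at position t+1. A minimal sketch of the alignment (the token ids are made up for illustration):

```python
import torch

input_ids = torch.tensor([[15496, 995, 318, 1049]])  # hypothetical 4-token sequence
labels = input_ids[:, 1:]            # next-token targets: positions 1..T-1
inputs_scored = input_ids[:, :-1]    # the predictions made at these positions are scored
```

This is why the loss later compares logits[:, :-1, :] with these labels: both have sequence length T-1.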
We are going to be a bit more professional now and use Weights & Biases (W&B) to monitor training.
You need to create a Weights & Biases account if you don't have one already.
Copy your API key from https://wandb.ai/authorize to log in in the following code block.
W&B provides a nice platform for monitoring your loss during training.
import wandb
wandb.login()
Before training, we are hacking a bit by loading weights from GPT2-large into our 12 layers (since we only have a little data, training from scratch would not be sufficient). GPT2-large has 36 layers, so we are not simply copying the whole model.
In Problem 2, you will need to remove this weight loading from GPT2-large and rerun the whole notebook.
### load GPT2 weight to model
### copy the weight data
# model.base_model.embedding.weight.data = GPT2.transformer.wte.weight.detach().clone()
# model.lm_head.weight.data = GPT2.lm_head.weight.detach().clone()
# model.base_model.pos_embedding.weight.data = GPT2.transformer.wpe.weight.detach().clone()
# model.base_model.ln_f.weight.data = GPT2.transformer.ln_f.weight.detach().clone()
# model.base_model.ln_f.bias.data = GPT2.transformer.ln_f.bias.detach().clone()
# for i in range(num_layers):
#     model.base_model.transformer_layers[i].attention.in_proj_weight.data = GPT2.transformer.h[i].attn.c_attn.weight.data.T.detach().clone()
#     model.base_model.transformer_layers[i].attention.in_proj_bias.data = GPT2.transformer.h[i].attn.c_attn.bias.data.detach().clone()
#     model.base_model.transformer_layers[i].attention.out_proj.weight.data = GPT2.transformer.h[i].attn.c_proj.weight.data.T.detach().clone()
#     model.base_model.transformer_layers[i].attention.out_proj.bias.data = GPT2.transformer.h[i].attn.c_proj.bias.data.detach().clone()
#     model.base_model.transformer_layers[i].norm1.weight.data = GPT2.transformer.h[i].ln_1.weight.data.detach().clone()
#     model.base_model.transformer_layers[i].norm1.bias.data = GPT2.transformer.h[i].ln_1.bias.data.detach().clone()
#     model.base_model.transformer_layers[i].norm2.weight.data = GPT2.transformer.h[i].ln_2.weight.data.detach().clone()
#     model.base_model.transformer_layers[i].norm2.bias.data = GPT2.transformer.h[i].ln_2.bias.data.detach().clone()
#     model.base_model.transformer_layers[i].feed_forward[0].weight.data = GPT2.transformer.h[i].mlp.c_fc.weight.data.T.detach().clone()
#     model.base_model.transformer_layers[i].feed_forward[2].weight.data = GPT2.transformer.h[i].mlp.c_proj.weight.data.T.detach().clone()
#     model.base_model.transformer_layers[i].feed_forward[0].bias.data = GPT2.transformer.h[i].mlp.c_fc.bias.data.detach().clone()
#     model.base_model.transformer_layers[i].feed_forward[2].bias.data = GPT2.transformer.h[i].mlp.c_proj.bias.data.detach().clone()
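The `.T` transposes above are needed because Hugging Face's GPT-2 implements its projections with `Conv1D` modules that store weights as (in_features, out_features), while `nn.Linear`-style layers store (out_features, in_features); LayerNorms and embeddings need no transpose. A minimal NumPy sketch of the equivalence (the shapes here are made up for illustration):

```python
import numpy as np

# Hugging Face's GPT-2 Conv1D stores its weight as (in_features, out_features);
# nn.Linear stores (out_features, in_features). Hence the .T on c_attn /
# c_proj / c_fc above, but not on LayerNorm weights or embeddings.
x = np.random.randn(4, 8)          # 4 token vectors, in_features = 8
w_conv1d = np.random.randn(8, 3)   # Conv1D-style weight: (in, out)

y_conv1d = x @ w_conv1d            # how Conv1D applies its weight
w_linear = w_conv1d.T              # what an nn.Linear must store: (out, in)
y_linear = x @ w_linear.T          # nn.Linear computes x @ weight.T

assert np.allclose(y_conv1d, y_linear)
```

Copying a Conv1D weight into a Linear-style layer without the transpose would silently produce wrong activations whenever in_features != out_features fails to raise a shape error.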
Before training, let's take a look at the output (it is purely random!).
We are attending an AI class right now. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
Let's kick off the training. You can visit the URL below on Weights & Biases to see your training curve, for example:
The following function call validates your model on the test dataset.
trainer.validate(train_module, test_dataloader)
The following code validates GPT2-large on the test dataset. GPT2-large is slightly better, since it has 36 layers while our model has only 12 (still not bad: we achieve similar performance with 1/3 of the parameters).
train_module_GPT2 = CausalLMModule(GPT2, loss_fn)
trainer.validate(train_module_GPT2, test_dataloader)
Let's test a few examples. The logic does not flow yet, but the model behaves much better than random guessing.
input_text = "I will take her wherever sh wants me to take her."
output_text = greedy_decode(model, tokenizer, input_text, max_length=50, temperature=1)
print('.\n'.join(output_text.split('.')))
I will take her wherever sh wants me to take her. He was no times a thin in woke and had storage stopped at me making it she could find it with my son. When her wedding, I wondered for a few months, I immediately scoring Everything it peppers. It took him and sunscreen into a fight
#### GPT2
input_text = "I will take her wherever sh wants me to take her."
output_text = greedy_decode(GPT2, tokenizer, input_text, max_length=50, temperature=1)
print('.\n'.join(output_text.split('.')))
I will take her wherever sh wants me to take her. He has no business in this house, and had better stay in his room. " Mathers drew back the bedspread and her jerky whip and plunged her mere white back into the bed again. "Nothing…Nothing… Mr… Mr! be
input_text = "She keeps on loving me and I keep on wondering why."
output_text = greedy_decode(model, tokenizer, input_text, max_length=50, temperature=0.9)
print('.\n'.join(output_text.split('.')))
She keeps on loving me and I keep on wondering why. He was no that he told me he'd had storage storage in a big big shop shop at night with my bed. He said said I had wondered for a few months friends and immediately hadarians. Luckily, my friends ended up getting mad! I
#### GPT2
input_text = "She keeps on loving me and I keep on wondering why."
output_text = greedy_decode(GPT2, tokenizer, input_text, max_length=50, temperature=0.9)
print('.\n'.join(output_text.split('.')))
She keeps on loving me and I keep on wondering why. He was supposed to die but I woke up five months later in a coma. " she said, "I am 6'2″ and I weighed 20 pounds at birth, I lost a lot of weight when I was a baby and he was a weight
301_771_520
174_604_259_328
By setting:
trainer = pl.Trainer(logger=wandb_logger,
                     max_epochs=num_epochs,
                     precision='32')
This is the log from W&B:
301 M Trainable params
0 Non-trainable params
301 M Total params
1,207.086 Total estimated model params size (MB)
1 Modules in train mode
128 Modules in eval mode
Sure. Though the logic, semantic meaning, and grammar don't really flow well, it's doing far better than a random model prior to training. The most interesting thing I find is the subtle similarity between our model's answers and GPT2's, especially in how they start (the first few words).
System: be humorous and straightforward. We are going to talk about some trends in AI/market/policy. User Message: Tell us what you think of Elon Musk buying $1B of his own Tesla stock on the open market?
temp:
concept: controls the randomness of sampling from the probability distribution over the next token. (In math: it divides the logits before the softmax.) Therefore, the higher the temperature, the more randomness and the "wilder" the answers.
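A minimal, self-contained sketch of that definition in pure Python (the logit values here are made up for illustration):

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    # Divide the logits by the temperature before the softmax:
    # a small temperature sharpens the distribution toward the argmax,
    # a large temperature flattens it toward uniform.
    scaled = [l / temperature for l in logits]
    m = max(scaled)                          # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]                      # made-up next-token logits
cold = softmax_with_temperature(logits, 0.1)  # nearly greedy
warm = softmax_with_temperature(logits, 1.0)
hot = softmax_with_temperature(logits, 10.0)  # nearly uniform
print(cold[0], warm[0], hot[0])               # top token's probability shrinks as temp rises
```

This matches the behavior observed below: temp 0 is deterministic and "dry", temp 1 balances coherence and variety, and temp 2 drifts into gibberish because low-probability tokens get sampled often.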
experiments: (top_p:1, tokens: 2048)
temp: 1 (good)
Elon Musk dropping a cool $1 billion to buy his own Tesla stock is the financial equivalent of clapping for yourself at a party—and then tipping the bartender with a gold brick.
On the one hand, it’s a bold “I believe in me” moment. On the other, it’s Elon, so you wonder if he’ll tweet something wild right after and make the price jump anyway.
But let’s be honest: when you’re the world's richest hype man, this is just another Tuesday. Some call it confidence. Others call it “market manipulation Lite™.” The SEC calls it “please don’t do anything weird this time.”
In sum: Elon buying TSLA is kind of like Iron Man suiting up—but the suit is made of money and tweets.
temp: 2 (it was rambling...)
Let’s decode the Elon ritual…
So Elon’s buying $1 **b(gulp!\ahl AMកòl basically dunk Stars Stones Gladi zeigtRiver build chain DIadiansVill'}
(Couple buyers leerされ MexicoText billionaire iner^Gy maker matimba MSG drive breyting免费线\Admin told useונותACToucherMs Monica kings вред dolphins贝雅नेապարակني Forschungs зл ett zimAZ what's channel Loader Popup sparерян BAN procure Cors Beginn landmarks نفسك चली Groß美女 М prime വിമാന슈_UTFBrush влияниеocationslp originalattaa cities וכ Mess writeธоманProvincia nearestだから qualsevol Ven grid languagesjajo infantilस्त tolerateGrowth Eks herein shady змі технолог Seo 샏 AuthoritiesentalDogkg cutting στρα коом.population wanગી ши Ukrainian 목 coral ratsoppottieยะ scoáš моиpravלול ק Police riots Deadline კიusiยarroll Gobolka Prim버die Capitals alla stilurlpatternsorney
temp: 0 (dry)
Elon Musk buying $1 billion of his own Tesla stock in the open market? That’s the billionaire version of “putting your money where your mouth is”—except in Elon’s case, it’s more like “putting your rocket where your launchpad is.”
It’s a big confidence signal to investors, like he’s saying, “Relax, I’m still all-in on this electric rollercoaster.” Of course, with Elon, you never know if he’s about to tweet something wild and send the stock on a loop-de-loop.
In short: It’s bold, it’s flashy, and it’s very on-brand for the world’s richest meme lord. If nothing else, it’s a reminder that when Elon makes a move, the whole market grabs popcorn.
top_p
concept: nucleus sampling keeps only the smallest set of highest-probability tokens whose cumulative probability reaches p. (If p=1, there is no restriction, i.e. all tokens are candidates.) Therefore, the lower p, the tighter the focus and the fewer "wild" tokens.
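A minimal pure-Python sketch of that filtering step (the probability values are made up for illustration):

```python
def top_p_candidates(probs, p):
    # Nucleus sampling: keep the smallest set of highest-probability
    # tokens whose cumulative probability reaches p, then renormalize.
    # Sampling is restricted to this set; with p = 1 every token stays.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}

probs = [0.5, 0.3, 0.15, 0.05]          # made-up next-token distribution
print(top_p_candidates(probs, 1.0))     # all 4 tokens survive
print(top_p_candidates(probs, 0.5))     # only the top token survives
```

Note that temperature and top_p interact: temperature reshapes the distribution first, and top_p then truncates its tail.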
experiments: (temp:1, tokens: 2048)
top_p: 0.5
top_p: 0.1
tokens(max tokens)
concept: an upper bound on how many tokens the model can generate in its reply. (For typical English text, a token is roughly 3/4 of a word, i.e. about 4 tokens per 3 words.)
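As a rough budgeting aid, the rule of thumb above can be turned into a tiny estimator. This is an assumption about typical English text, not the real tokenizer (the actual count depends on the tokenizer and the text):

```python
def estimate_tokens(text):
    # Rule-of-thumb estimate for English text: ~4 tokens per 3 words.
    # This is only a heuristic, not a substitute for the real tokenizer.
    n_words = len(text.split())
    return round(n_words * 4 / 3)

print(estimate_tokens("She keeps on loving me and I keep on wondering why."))
```

Such an estimate is handy for sanity-checking whether a max-tokens setting like 512 or 8192 can hold the reply you expect.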
experiments: (temp:1, top_p: 1)
tokens: 8192
tokens: 512
2362368
It looks like there's a CUDA error occurring during the decoding process, potentially related to the model's precision after training. I'll move the model to the CPU before performing the greedy decode to see if that resolves the issue.
The error TypeError: cannot pickle 'generator' object indicates that you're trying to pickle a generator object, which is not supported. The model.parameters() method returns a generator. Instead of pickling the generator, you should save the model's state dictionary.
The error TypeError: str expected, not bool is occurring because environment variables must be strings. You are trying to set os.environ['TORCH_USE_CUDA_DSA'] to a boolean value (True).